{-|
  hback is a dual n-back memory test based primarily on the work of:
  Jaeggi, Buschkuehl, et al. (2008)
  Improving Fluid Intelligence With Training on Working Memory.
  Proceedings of the National Academy of Sciences of the United States of America, 105(19), 6829-6833

  Any reference in the comments to [Paper] refers to the above work.
-}

{-# OPTIONS -fbang-patterns #-}

module Main where

import System.Exit
import IO
import System.Cmd (system)
import Directory (getDirectoryContents)
import System.Environment (getArgs)
import GHC.Conc (threadDelay)
import System.Posix.Unistd (usleep)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.List (intersperse, isSuffixOf)
import Text.Printf
import Control.Monad
import Graphics.UI.Gtk hiding (fill)
import Graphics.UI.Gtk.General.General
import Graphics.UI.Gtk.Glade
import Graphics.Rendering.Cairo
import Graphics.Rendering.Cairo.SVG

import Data.IORef
import Random

import Paths_hback

-- ========== Data ==========    

-- | A visual stimulus: the (x, y) cell of the 3x3 grid in which the white
--  square is drawn.
type Visual = (Int, Int)
-- | An audio stimulus: path of the sound file to play.
type Audio  = FilePath

-- | Outcome of one response in one modality for one turn.
--  'None' marks unscored turns (the first n turns of a game have no n-back
--  partner); the other four compare "was it a target" with "was the button on".
data Prediction = None | TruePositive | FalsePositive | FalseNegative | TrueNegative
                deriving (Eq, Show)
-- | The n of the n-back task.
type Level = Int
-- | Tick counter within one turn: the current tick and the last tick index.
--  'tick' wraps the current tick back to 0 once it passes the last index.
data Timer = Timer Frac Total
type Frac = Int
type Total = Int
-- | One game (block) played at a fixed level.
--  The stimulus lists pair each turn's stimulus with Just True (target),
--  Just False (non-target) or Nothing (unscored turn). The prediction lists
--  grow by one entry per completed turn.
data Game  = Game  { gameLevel     :: Level,
                     gameVisuals   :: [(Visual, Maybe Bool)],
                     gameAudios    :: [(Audio, Maybe Bool)],
                     gameVPreds    :: [Prediction],
                     gameAPreds    :: [Prediction]
                   } deriving Show
-- | Session state shared (through an IORef) by the timers and the key handler.
--  stateGames holds the games played so far, newest first; totalGames is the
--  number of games after which the session ends; statePaused is toggled by 'p'.
data State = State { stateGames    :: [Game],
                     stateTimer    :: Timer,
                     stateGUI      :: GUI,
                     totalGames    :: Int,
                     statePaused   :: Bool
                   }
-- | Widgets looked up from the glade file in 'main'.
--  guiVButton / guiAButton are the visual and audio response toggle buttons.
data GUI   = GUI   { guiWindow     :: Window,
                     guiLevelLabel :: Label,
                     guiScoreLabel :: Label,
                     guiDrawArea   :: DrawingArea,
                     guiVButton    :: ToggleButton,
                     guiAButton    :: ToggleButton
                   }

-- | When True, nothing is appended to the score history file.
turnOffLogging :: Bool
turnOffLogging = False

-- | Number of scored turns per game. A game at level n runs blockSize + n
--  turns; the first n are unscored because they have no n-back partner.
--  [Paper] used 20 + n.
blockSize :: Int
blockSize = 20

-- | Index of the last timer tick in a turn: ticks 0..timerFrac, i.e. 6 ticks.
--  One turn therefore lasts (timerFrac + 1) * timerFreq = 3 s. The stimulus
--  is visible for the first tick (500 ms), followed by a 2.5 s pause, as in [Paper].
timerFrac :: Int
timerFrac = 5

-- | Length of one timer tick in milliseconds.
timerFreq :: Int
timerFreq = 500

-- | Number of games in a session. Total time is roughly
--  totalNumGames * (blockSize + n) * 3 s / 60 minutes. Following [Paper], the
--  default of 20 gives about 20 minutes of training.
totalNumGames :: Int
totalNumGames = 20

-- | N-back level of the first game (can be overridden on the command line).
initLevel :: Int
initLevel = 1

-- | Index of the current turn of a game. It equals the number of completed
--  turns, because exactly one prediction is recorded per completed turn.
gameTurn :: Game -> Int
gameTurn = length . gameVPreds

-- | Timer positioned at the first tick (0) of a turn that ends at tick timerFrac.
newTimer :: Timer
newTimer = Timer firstTick timerFrac
  where firstTick = 0

-- | The visual stimuli: every cell of the 3x3 grid except the centre (1,1),
--  which holds the fixation cross drawn by 'renderCross'.
imageList :: IO [Visual]
imageList = return [ (a, b) | a <- [0 .. 2], b <- [0 .. 2], (a, b) /= (1, 1) ]

-- | The audio stimuli: full paths of all files in the "sounds" directory
--  under the package data directory whose name ends in ".wav".
--  Non-matching entries, including "." and "..", are skipped.
soundList :: IO [Audio]
soundList = do
  d <- getDataDir
  let dir = d ++ "/sounds/"
  l <- getDirectoryContents dir
  -- Test the suffix. `tail f == ".wav"` only matched five-character names.
  return $ map (dir ++) $ filter (".wav" `isSuffixOf`) l

-- | Numeric code of a 'Prediction' as written to the score log:
--  None=0, TruePositive=1, FalsePositive=2, FalseNegative=3, TrueNegative=4.
predictionToInt :: Prediction -> Int
predictionToInt p = case p of
  None          -> 0
  TruePositive  -> 1
  FalsePositive -> 2
  FalseNegative -> 3
  TrueNegative  -> 4

-- ========== Scoring ==========

-- | Record one turn's outcome by appending the visual and audio predictions
--  to the end of the game's prediction lists.
addPredictions :: Game -> Prediction -> Prediction -> Game
addPredictions game vNew aNew =
  game { gameVPreds = gameVPreds game ++ [vNew]
       , gameAPreds = gameAPreds game ++ [aNew] }

-- | Keep only scored predictions, dropping the 'None' entries of unscored turns.
realPredictions :: [Prediction] -> [Prediction]
realPredictions preds = [p | p <- preds, p /= None]
                                                 
-- |gameScore vs as combines the visual and audio predictions of one game
--  into a single accuracy in [0, 1]: (TruePositive + TrueNegative) divided by
--  the number of scored predictions, counted over both modalities.
--  Unscored ('None') entries are ignored. A game with no scored turns scores 0
--  (previously 0/0 = NaN).
gameScore :: [Prediction] -> [Prediction] -> Double
gameScore v' a'
    | den == 0  = 0
    | otherwise = fromIntegral num / fromIntegral den
    where
      scored  = realPredictions v' ++ realPredictions a'
      correct = filter (\x -> x == TruePositive || x == TrueNegative) scored
      num     = length correct
      -- Count both lists instead of assuming they have equal length (2 * |v|).
      den     = length scored

-- |chooseNextLevel n vPreds aPreds picks the level of the next game from the
--  errors (false positives + false negatives) made in the previous one,
--  following the protocol of [Paper]:
--    * fewer than 3 errors in each modality -> n + 1;
--    * more than 5 errors in total          -> n - 1 (never below 1);
--    * otherwise                            -> n.
chooseNextLevel :: Int -> [Prediction] -> [Prediction] -> Int
chooseNextLevel n vPreds aPreds
    | vErrors < 3, aErrors < 3 = inc n
    | vErrors + aErrors > 5    = max 1 (dec n)
    | otherwise                = n
    where
      isError p   = p `elem` [FalseNegative, FalsePositive]
      countErrors = length . filter isError
      vErrors     = countErrors vPreds
      aErrors     = countErrors aPreds

-- |score isTarget pressed classifies one response. isTarget says whether the
--  stimulus matched the one n turns back; pressed says whether the player had
--  the button active.
score :: Bool -> Bool -> Prediction
score True  True  = TruePositive
score False True  = FalsePositive
score True  False = FalseNegative
score False False = TrueNegative


-- ========== Main ==========
            
-- | Print the command-line synopsis to stdout.
printUsage :: IO ()
printUsage = putStrLn usage
  where
    usage = "hback b n\n b is the number of tests [default=20]\n n determines the starting n-back test [default=1]"

-- | Entry point. Parses the optional arguments, builds the GTK window from
--  "hback.glade" and starts the first game.
--
--  Usage: @hback [b [n]]@, where @b@ is the number of games and @n@ the
--  starting n-back level. Both must be positive integers; missing ones fall
--  back to totalNumGames / initLevel. Invalid or extra arguments print an
--  error and exit with status 1, instead of failing in 'read' or on a
--  non-exhaustive pattern. A level below 1 used to crash later, when
--  matchStim indexed an empty history.
main :: IO ()
main = do
  args <- getArgs
  printUsage
  (!totalNumGames', !initLevel') <- case parseArgs args of
      Just opts -> return opts
      Nothing   -> do
        hPutStrLn stderr "hback: expected at most two positive integer arguments"
        exitWith (ExitFailure 1)
  initGUI
  gFile <- getDataFileName "hback.glade"
  windowXmlM <- xmlNew gFile
  -- The glade file is looked up in the package data directory, so report
  -- the actual path rather than "the current directory".
  windowXml <- case windowXmlM of
                 Just xml -> return xml
                 Nothing  -> do
                   hPutStrLn stderr ("hback: can't load the glade file " ++ show gFile)
                   exitWith (ExitFailure 1)
  window <- xmlGetWidget windowXml castToWindow "hback"
  onDestroy window mainQuit

  label <- xmlGetWidget windowXml castToLabel "testLabel"
  scLabel <- xmlGetWidget windowXml castToLabel "scoreLabel"
  img <- xmlGetWidget windowXml castToDrawingArea "drawArea"
  visualBtn <- xmlGetWidget windowXml castToToggleButton "visualBtn"
  audioBtn  <- xmlGetWidget windowXml castToToggleButton "audioBtn"

  stateRef <- newIORef $ State [] newTimer
                           (GUI window label scLabel img visualBtn audioBtn)
                           totalNumGames' False
  onKeyPress window (processEvent stateRef)

  widgetShowAll window
  logInitGame
  startNewGame stateRef initLevel'
  mainGUI
    where
      -- (number of games, starting level); Nothing for malformed input.
      parseArgs :: [String] -> Maybe (Int, Int)
      parseArgs []     = Just (totalNumGames, initLevel)
      parseArgs [g]    = fmap (\n -> (n, initLevel)) (readPositive g)
      parseArgs [g, l] = do
        n   <- readPositive g
        lvl <- readPositive l
        return (n, lvl)
      parseArgs _      = Nothing
      -- Accept only a complete integer that is at least 1.
      readPositive :: String -> Maybe Int
      readPositive s = case reads s of
                         [(n, "")] | n > 0 -> Just n
                         _                 -> Nothing

-- |startNewGame stateRef level generates a fresh game at the given level and
--  pushes it onto stateGames with a reset timer. It then updates the level
--  label and installs the intro timeout (timerInit), which later hands over
--  to 'timer'.
startNewGame :: IORef State -> Level -> IO ()
startNewGame stateRef level = do
  imgList <- imageList
  sndList <- soundList
  preds <- shuffledPredictions level
  visuals <- matchStim imgList level (map fst preds) []
  audios  <- matchStim sndList level (map snd preds) []
  let game = Game level visuals audios [] []
  (State games _ gui tL p) <- readIORef stateRef
  writeIORef stateRef $ State (game:games) newTimer gui tL p
  labelSetText (guiLevelLabel gui) $ show level ++ "-Back Test"
  -- Use the shared tick length instead of a duplicated literal 500.
  _ <- timeoutAdd (timerInit stateRef) timerFreq
  return ()

-- | The scored trials of one game, before shuffling, as
--  (visual target?, audio target?) pairs:
--  2 dual targets, 4 visual-only, 4 audio-only and blockSize - 10 trials
--  without any target.
makePredictions :: [(Maybe Bool, Maybe Bool)]
makePredictions = concat
  [ replicate 2 (Just True, Just True)
  , replicate 4 (Just True, Just False)
  , replicate 4 (Just False, Just True)
  , replicate (blockSize - 10) (Just False, Just False)
  ]

-- | Target plan for a game at the given level. It starts with `level`
--  unscored (Nothing, Nothing) turns, which have no n-back partner, followed
--  by the trials of 'makePredictions' in random order.
shuffledPredictions :: Level -> IO [(Maybe Bool, Maybe Bool)]
shuffledPredictions level = do
  let trials = makePredictions
  choices <- getRandomDecList (length trials - 1)
  let warmUp = replicate level (Nothing, Nothing)
  return (warmUp ++ shuffle trials choices)

-- |matchStim pool level wanted acc picks one stimulus from pool per entry of
--  wanted. acc holds the picks made so far, newest first; callers pass [].
--    * Nothing    -> any random stimulus (unscored turn);
--    * Just True  -> the stimulus picked `level` turns earlier (a target);
--    * Just False -> a random stimulus different from that one.
--  Returns the picks in turn order, each paired with its wanted flag.
--  NOTE(review): a Just entry needs at least `level` earlier picks. This
--  holds because shuffledPredictions starts with `level` Nothing entries.
matchStim :: Ord a => [a] -> Int -> [Maybe Bool] -> [(a, Maybe Bool)] -> IO [(a, Maybe Bool)]
matchStim pool level = go
  where
    go [] acc = return (reverse acc)
    go (want : rest) acc = do
      picked <- pick want acc
      go rest (picked : acc)
    -- The stimulus shown `level` turns before the current one.
    nBack acc = fst (head (drop (dec level) acc))
    pick Nothing _ = do
      x <- randomElem pool
      return (x, Nothing)
    pick (Just True) acc = return (nBack acc, Just True)
    pick (Just False) acc = do
      x <- randomElem (remove (nBack acc) pool)
      return (x, Just False)

-- |endGame prints the score of every game in the session, oldest first.
--  It then leaves the GTK main loop and exits the program successfully.
endGame :: IORef State -> IO ()
endGame stateRef = do
  games <- liftM (reverse . stateGames) (readIORef stateRef)
  putStrLn "Game finished"
  forM_ games $ \g ->
    putStrLn ("Level " ++ show (gameLevel g) ++ " : " ++ show (gameScore (gameVPreds g) (gameAPreds g)))
  mainQuit
  exitWith ExitSuccess
           
-- ========== Timers and Events ==========

-- |timerInit stateRef is the timeout callback for the intro phase of a game.
--  While paused it does nothing and stays installed. On tick 0 it draws the
--  "Ready?" screen. When the tick counter reaches its last value, it installs
--  'timer' (every timerFreq ms) and returns False so this intro timeout is
--  dropped. Every non-paused call advances the tick counter.
timerInit :: IORef State -> IO Bool
timerInit stateRef = readIORef stateRef >>= step
  where
    step st
      | statePaused st = return True
      | otherwise = case stateTimer st of
          Timer 0 _ -> do
            renderImage (guiDrawArea (stateGUI st)) renderNewGame
            advance True
          Timer frac total
            | frac == total -> do
                _ <- timeoutAdd (timer stateRef) timerFreq
                advance False
            | otherwise -> advance True
    advance keepRunning = stateTick stateRef >> return keepRunning
              
-- |timer stateRef is the per-tick timeout callback of a running game.
--  It takes a snapshot of the state and delegates to timer'.
timer :: IORef State -> IO Bool
timer stateRef = readIORef stateRef >>= timer' stateRef
-- |timer' stateRef state runs one tick of the current game, the head of
--  stateGames. 'timer' calls it every timerFreq ms. Each turn spans ticks
--  0..tt of the Timer:
--    * tick 0:  draw the visual stimulus, play the sound, turn both buttons off;
--    * tick 1:  redraw the empty board (renderBlank);
--    * tick tt: read both buttons and append one prediction per modality
--               ('None' for unscored turns).
--  After blockSize + n turns the game is logged. Then the session ends (once
--  totalGames games have been played) or the next game starts at the level
--  picked by 'chooseNextLevel'. In that case this returns False so the
--  timeout is removed; the new game installs its own timers via 'startNewGame'.
--  NOTE(review): the pattern is partial and assumes stateGames is non-empty.
--  That holds because startNewGame adds a game before any timer runs.
--  NOTE(review): scoring only happens if tt >= 2. With timerFrac < 2 the
--  tick-tt branch is shadowed by the 0/1 cases.
timer' :: IORef State -> State -> IO Bool
timer' stateRef state@(State games@(game:prevGames) tm@(Timer t tt) gui total paused)
    | statePaused state =                        -- game is paused
        return True
    | turn >= blockSize + gameLevel game = do    -- current game finished
        if (length games >= totalGames state)
           then do
             logGame stateRef
             endGame stateRef
           else do
             logGame stateRef
             startNewGame stateRef (chooseNextLevel (gameLevel game) (gameVPreds game) (gameAPreds game))
        return False
    | otherwise = do
        -- Stimulus and target flag (Just True / Just False / Nothing) for this turn.
        let (vZ, vB) = gameVisuals game !! turn
        let (aZ, aB) = gameAudios  game !! turn
        case t of
          -- Start of turn: present both stimuli and clear the previous answers.
          0 -> do
            renderImage (guiDrawArea gui) $ renderRect vZ
            playSound aZ
            toggleButtonSetActive (guiVButton gui) False
            toggleButtonSetActive (guiAButton gui) False
          -- The visual stimulus stays on screen for a single tick.
          1 -> do
            renderImage (guiDrawArea gui) renderBlank
          -- Last tick: score the buttons and record the turn. The timer `tm`
          -- read at the start of this tick is written back; stateTick below
          -- then advances it.
          _ -> when (t == tt)
                   (do
                     (vs', as') <- case (vB, aB) of
                                      (Just vB', Just aB') -> do
                                        b1 <- toggleButtonGetActive (guiVButton gui)
                                        b2 <- toggleButtonGetActive (guiAButton gui)
                                        return ((score vB' b1), (score aB' b2))
                                      _  -> return (None, None)
                     writeIORef stateRef $ State (addPredictions game vs' as' : prevGames) tm gui total paused)
        stateTick stateRef
        return True
    where
      -- Index of the current turn (= number of turns already scored).
      turn = gameTurn game

-- |stateTick stateRef advances the stored timer by one tick and leaves the
--  rest of the state unchanged.
stateTick :: IORef State -> IO ()
stateTick stateRef = do
  st <- readIORef stateRef
  writeIORef stateRef st { stateTimer = tick (stateTimer st) }

-- |processEvent stateRef event handles key presses in the main window:
--    * 'p' toggles pause, drawing the pause overlay or a blank board;
--    * 'l' or the Left arrow toggles the audio button;
--    * 'a' or the Right arrow toggles the visual button.
--  The arrow keys are the ones the "Ready?" help screen advertises
--  (LeftArrow -> Sound, RightArrow -> Graphic); previously only the letters
--  worked. Returns True for the keys it handles and False for anything else.
processEvent :: IORef State -> Event -> IO Bool
processEvent stateRef (Key {eventKeyName = keyName, eventKeyChar = char})
    | char == Just 'p'                       = togglePause >> return True
    | char == Just 'l' || keyName == "Left"  = flipToggle guiAButton >> return True
    | char == Just 'a' || keyName == "Right" = flipToggle guiVButton >> return True
    | otherwise                              = return False
    where
      togglePause = do
        State g t gui tt p <- readIORef stateRef
        renderImage (guiDrawArea gui) (if p then renderBlank else renderPause)
        writeIORef stateRef $ State g t gui tt (not p)
      flipToggle whichButton = do
        btn <- liftM (whichButton . stateGUI) (readIORef stateRef)
        active <- toggleButtonGetActive btn
        toggleButtonSetActive btn (not active)
processEvent _ _ = return False
         
-- |tick advances a Timer by one. After the last tick index it wraps back to 0.
tick :: Timer -> Timer
tick (Timer frac total) = Timer frac' total
  where
    next  = inc frac
    frac' = if next > total then 0 else next

-- ========== Rendering ==========

-- |renderNewGame width height draws the intro screen: "Ready?" on black,
--  followed by three lines of key help.
renderNewGame :: Int -> Int -> Render ()
renderNewGame w' h' = do
  setSourceRGB 0 0 0
  paint
  setSourceRGB 1 1 1
  setFontSize 30
  moveTo (w / 2 - 70) (h / 4)
  showText "Ready?"
  setFontSize 20
  -- Help lines sit at 3/6, 4/6 and 5/6 of the height.
  forM_ (zip [3, 4, 5] helpLines) $ \(row, txt) -> do
    moveTo (w / 2 - 130) (h / 6 * row)
    showText txt
  where
    w = fromIntegral w' :: Double
    h = fromIntegral h' :: Double
    helpLines = [ "LeftArrow    -> Sound"
                , "RightArrow -> Graphic"
                , "      'p'         ->  Pause" ]

-- |renderPause width height overlays a square of black at opacity 0.7 with an
--  orange "[Paused]" label. The square's side is the smaller window dimension.
--  NOTE(review): for sides under 150 px the scale factor is 0, which gives
--  the label font-size 0.
renderPause :: Int -> Int -> Render ()
renderPause wU hU = svgRenderFromString markup
  where
    side   = min wU hU                -- keep the overlay square
    unit   = (side `div` 150) * 5     -- text scale, grows in steps with the window
    markup = concat
      [ printf "<svg width=\"%d\" height=\"%d\">" side side
      , printf "<rect style=\"fill:black\" opacity=\"0.7\" width=\"%d\" height=\"%d\" x=\"0\" y=\"0\" />" side side
      , printf "<text x=\"%d\" y=\"%d\"" ((side `div` 2) - (unit * 13)) ((side `div` 2) + unit)
      , printf " font-family=\"Verdana\" font-size=\"%d\" fill=\"orange\" >" (unit * 6)
      , "[Paused]"
      , "</text></svg>"
      ]

-- |renderBlank width height draws the empty board: only the fixation cross.
renderBlank :: Int -> Int -> Render ()
renderBlank width height = renderCross width height Nothing

-- |renderRect cell width height draws the board with a white square in the
--  given grid cell.
renderRect :: Visual -> Int -> Int -> Render ()
renderRect (col, row) width height = renderCross width height (Just (col, row))

-- |renderCross width height cell draws the board as SVG: a black square
--  (side = smaller window dimension), a white fixation cross in the centre
--  cell and, when cell is Just (x, y), a white square in that grid cell.
renderCross wU hU m =
  svgRenderFromString $ concat
    [ printf "<svg width=\"%d\" height=\"%d\">" side side
    , printf "<rect style=\"fill:black\" width=\"%d\" height=\"%d\" x=\"0\" y=\"0\" />" side side
    -- vertical bar of the cross
    , printf "<rect style=\"fill:white\" width=\"%d\" height=\"%d\" x=\"%d\" y=\"%d\" />" marg line x2 y2
    , maybe "" cellSquare m
    -- horizontal bar of the cross (closes the svg element)
    , printf "<rect style=\"fill:white\" width=\"%d\" height=\"%d\" x=\"%d\" y=\"%d\" /></svg>" line marg y2 x2
    ]
  where
    side = min wU hU          -- draw into a square
    cell = side `div` 3       -- one grid cell
    marg = cell `div` 10      -- margin inside a cell / bar thickness
    line = marg * 8           -- square side / bar length
    x2   = cell + (cell `div` 2) - (marg `div` 2)
    y2   = cell + marg
    cellSquare (x, y) =
      printf ("<rect style=\"fill:white\" width=\"%d\" height=\"%d\" "
              ++ "x=\"%d\" y=\"%d\" />")
             line line (marg + (cell * x)) (marg + (cell * y))

-- |renderImage area paint runs the drawing action `paint`, given the area's
--  current width and height, on the area's draw window. Any result of the
--  render is discarded.
renderImage :: DrawingArea -> (Int -> Int -> Render ()) -> IO ()
renderImage drawArea paintFn = do
  (width, height) <- widgetGetSize drawArea
  target <- widgetGetDrawWindow drawArea
  renderWithDrawable target (paintFn width height) >> return ()

-- |playSound f starts mplayer on the sound file f in the background, with
--  its standard output discarded, and returns immediately.
--  The path is single-quoted for the shell. File names containing spaces or
--  shell metacharacters then reach mplayer as one literal argument, instead
--  of being split or interpreted by /bin/sh.
playSound :: Audio -> IO ()
playSound f = do
  _ <- system $ "mplayer " ++ shellQuote f ++ " > /dev/null &"
  return ()
    where
      -- POSIX sh quoting: wrap in '...' and turn each embedded ' into '\''.
      shellQuote s = "'" ++ concatMap escape s ++ "'"
      escape '\'' = "'\\''"
      escape c    = [c]

-- ========== Utils ==========

-- |logInitGame appends the current POSIX time as a session marker line to
--  "user_score_history.db", unless logging is turned off.
logInitGame :: IO ()
logInitGame
  | turnOffLogging = return ()
  | otherwise = do
      stamp <- getPOSIXTime
      bracket (openFile "user_score_history.db" AppendMode) hClose $ \h ->
        hPrintf h "%s\n" (show stamp)

-- |logGame appends the current game to "user_score_history.db", unless
--  logging is turned off. It writes three lines: "Level n", then the visual
--  and the audio predictions as space-separated codes from 'predictionToInt'.
logGame :: IORef State -> IO ()
logGame stateRef
  | turnOffLogging = return ()
  | otherwise = do
      game <- liftM (head . stateGames) (readIORef stateRef)
      let encode = unwords . map (show . predictionToInt)
      bracket (openFile "user_score_history.db" AppendMode) hClose $ \h ->
        hPrintf h "Level %d\n%s\n%s\n" (gameLevel game)
                (encode (gameVPreds game)) (encode (gameAPreds game))

-- |inc n is the successor of n.
inc :: Int -> Int
inc n = n + 1

-- |dec n is the predecessor of n.
dec :: Int -> Int
dec = subtract 1

-- |shuffle elems choices builds a permutation of elems. Each choice index
--  selects, and removes, one element among those still remaining. If the
--  i-th choice is uniform in [0, length elems - 1 - i], the result is a
--  uniform random permutation.
--  Normally length choices == length elems - 1. If the choices run out early,
--  the remaining elements are appended in their current order; this used to
--  be a pattern-match failure.
--  An index >= the number of remaining elements is an error.
shuffle :: [a] -> [Int] -> [a]
shuffle [] _ = []
shuffle elems [] = elems
shuffle elems (x:xs)
    | x >= length elems = error "shuffle: index too large"
    | otherwise = case splitAt x elems of
                    (before, picked : after) -> picked : shuffle (before ++ after) xs
                    (_, [])                  -> error "shuffle: impossible, elems is non-empty"

-- |getRandomDecList n returns n random Ints from the global StdGen. The i-th
--  element (from 0) lies in [0, n - i], so the first is in [0, n] and the
--  last in [0, 1]. These are the choice indices 'shuffle' needs for a list of
--  n + 1 elements. Non-positive n yields []; the old recursion stopped only
--  at exactly 0, so a negative n never terminated.
getRandomDecList :: Int -> IO [Int]
getRandomDecList n = mapM (\hi -> getStdRandom (randomR (0, hi))) [n, n - 1 .. 1]

-- |randomElem lst picks a random element of lst using the global StdGen.
--  An empty list fails immediately with a descriptive error, instead of an
--  opaque (!!) failure once the result is forced. This can happen with an
--  empty sounds directory, or when matchStim removes the only candidate.
randomElem :: [a] -> IO a
randomElem [] = error "randomElem: empty list (no stimuli to choose from)"
randomElem lst = do
  i <- getStdRandom (randomR (0, length lst - 1))
  return $ lst !! i

-- |remove x xs drops every occurrence of x from xs, keeping the rest in order.
remove :: Ord a => a -> [a] -> [a]
remove x xs = [y | y <- xs, y /= x]
